{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hackathon #6\n",
    "\n",
    "Written by Eleanor Quint\n",
    "\n",
    "Topics:\n",
    "- LSTM recurrent architecture\n",
    "- Attention\n",
    "\n",
    "This is all set up in an IPython notebook so you can run any code you want to experiment with. Feel free to edit any cell, or add new ones to run your own code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We'll start with our library imports...\n",
    "from __future__ import print_function\n",
    "\n",
    "import numpy as np                 # to use numpy arrays\n",
    "import tensorflow as tf            # to specify and run computation graphs\n",
    "import tensorflow_datasets as tfds # to load training data\n",
    "import matplotlib.pyplot as plt    # to visualize data and draw plots\n",
    "from tqdm import tqdm              # to track progress of loops"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### RNN/LSTM recap\n",
    "\n",
    "Recurrent neural networks (RNNs) are computation graphs with loops (i.e., not directed acyclic graphs). Because the backpropagation algorithm only works with DAGs, we have to unroll the RNN through time. TensorFlow provides code that handles this automatically.\n",
    "\n",
    "<img src=\"http://colah.github.io/posts/2015-08-Understanding-LSTMs/img/RNN-unrolled.png\" width=\"80%\">\n",
    "\n",
    "\n",
    "The most common RNN unit is the LSTM, depicted below:\n",
    "\n",
    "<img src=\"http://colah.github.io/posts/2015-08-Understanding-LSTMs/img/LSTM3-chain.png\" width=\"80%\">\n",
    "\n",
    "We can see that each unit takes 3 inputs and produces 3 outputs: two of which are forwarded to the same unit at the next timestep, and one true output, $h_t$, depicted coming out of the top of the cell.\n",
    "\n",
    "The upper right output going to the next timestep is the cell state. It carries long-term information between timesteps, and is calculated as:\n",
    "\n",
    "<img src=\"http://colah.github.io/posts/2015-08-Understanding-LSTMs/img/LSTM3-focus-C.png\" width=\"80%\">\n",
    "\n",
    "where the first term uses the forget gate $f_t$ to scale the previous state (potentially shrinking it to \"forget\" it), and the second term is the product of the update gate $i_t$ and the candidate state update $\\tilde{C}_t$. Written out, $C_t = f_t * C_{t-1} + i_t * \\tilde{C}_t$. Both the forget and update gates use a sigmoid activation, so their values lie in (0, 1).\n",
    "\n",
    "The true output and the second, lower output on the diagram are calculated by the output gate:\n",
    "\n",
    "<img src=\"http://colah.github.io/posts/2015-08-Understanding-LSTMs/img/LSTM3-focus-o.png\" width=\"80%\">\n",
    "\n",
    "First, $o_t$ is calculated from the output of the previous timestep concatenated with the current input. It is then multiplied elementwise with $\\tanh(C_t)$ to get the true output, $h_t = o_t * \\tanh(C_t)$. Passing this output on to the next timestep as the hidden state gives the unit a kind of short-term memory.\n",
    "\n",
    "(Images sourced from [Colah's Blog](http://colah.github.io/posts/2015-08-Understanding-LSTMs/))"
   ]
  },
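  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the recap above, here is a minimal NumPy sketch of a single LSTM step. The function name `lstm_step`, the stacked weight matrix `W`, and the gate ordering inside it are assumptions made for this example; they are not how `tf.keras.layers.LSTMCell` lays out its parameters.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def sigmoid(x):\n",
    "    return 1.0 / (1.0 + np.exp(-x))\n",
    "\n",
    "def lstm_step(x_t, h_prev, c_prev, W, b):\n",
    "    # W maps concat([h_prev, x_t]) to the four stacked gate pre-activations\n",
    "    # (ordering f, i, candidate, o is an assumption of this sketch)\n",
    "    z = np.concatenate([h_prev, x_t]) @ W + b\n",
    "    f, i, c_tilde, o = np.split(z, 4)\n",
    "    f_t = sigmoid(f)               # forget gate, in (0, 1)\n",
    "    i_t = sigmoid(i)               # update (input) gate, in (0, 1)\n",
    "    c_tilde_t = np.tanh(c_tilde)   # candidate state update\n",
    "    o_t = sigmoid(o)               # output gate, in (0, 1)\n",
    "    c_t = f_t * c_prev + i_t * c_tilde_t   # new cell state (long-term memory)\n",
    "    h_t = o_t * np.tanh(c_t)               # true output / hidden state\n",
    "    return h_t, c_t\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "input_size, hidden_size = 3, 4     # toy sizes\n",
    "W = rng.normal(scale=0.1, size=(hidden_size + input_size, 4 * hidden_size))\n",
    "b = np.zeros(4 * hidden_size)\n",
    "\n",
    "# Unroll over a sequence of 5 timesteps, carrying h and c forward\n",
    "h, c = np.zeros(hidden_size), np.zeros(hidden_size)\n",
    "for x_t in rng.normal(size=(5, input_size)):\n",
    "    h, c = lstm_step(x_t, h, c, W, b)\n",
    "print(h.shape, c.shape)\n",
    "```\n",
    "\n",
    "Because $h_t$ is a sigmoid gate times a tanh, every entry of the output stays inside (-1, 1), while the cell state $C_t$ is free to grow beyond that range."
   ]
  },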
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Today, we're going to teach a recurrent model how to classify sentiment by inputting a sequence of words and asking the model to estimate what the sentiment is. The IMDB review dataset has the text of movie reviews and a label for whether the review is negative (zero) or positive (one). First, we'll load the dataset and then set up a preprocessor to turn the words into integers using `tf.keras.layers.experimental.preprocessing.TextVectorization`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MAX_SEQ_LEN = 128\n",
    "MAX_TOKENS = 5000\n",
    "\n",
    "# load the text dataset\n",
    "ds = tfds.load('imdb_reviews')\n",
    "\n",
    "# Create TextVectorization layer\n",
    "vectorize_layer = tf.keras.layers.experimental.preprocessing.TextVectorization(\n",
    "    max_tokens=MAX_TOKENS,\n",
    "    output_mode='int',\n",
    "    output_sequence_length=MAX_SEQ_LEN)\n",
    "\n",
    "# Use `adapt` to create a vocabulary mapping words to integers\n",
    "train_text = ds['train'].map(lambda x: x['text'])\n",
    "vectorize_layer.adapt(train_text)\n",
    "\n",
    "# Let's print out a batch to see what it looks like in text and in integers\n",
    "for batch in ds['train'].batch(1):\n",
    "    text = batch['text']\n",
    "    print(list(zip(text.numpy(), vectorize_layer(text).numpy())))\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You might notice that the integer representation of each sentence ends in zeros. To ensure that every input has the same shape, each sequence is padded with zeros (or truncated) to exactly `MAX_SEQ_LEN` tokens, so `MAX_SEQ_LEN` should typically be as long as the longest sequence you want the model to see in full. `MAX_TOKENS` caps the vocabulary size: only the most frequent tokens get their own integer, and every other token maps to a shared out-of-vocabulary index.\n",
    "\n",
    "Now each word in the sequence is represented as an integer. However, this discrete representation fails to capture any semantic relationships between words. That is, the model wouldn't know that \"crimson\" and \"scarlet\" are more similar than \"red\" and \"blue\". The solution is to learn a word embedding as the first part of the model, which transforms each integer into a relatively small, dense vector (as compared to a one-hot). Similar words then tend to learn similar embedded representations. A more sophisticated model might use a pre-trained word embedding like [BERT](https://blog.google/products/search/search-language-understanding-bert/).\n",
    "\n",
    "We'll use [tf.keras.layers.Embedding](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding) to do this. It holds a trainable VOCAB_SIZE x EMBEDDING_SIZE matrix and looks up one row per integer token, the same operation as [tf.nn.embedding_lookup](https://www.tensorflow.org/api_docs/python/tf/nn/embedding_lookup). The embedding is learned from your dataset with training gradients."
   ]
  },
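  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the lookup concrete, the following NumPy sketch (toy sizes and variable names are made up for this example) shows that embedding an integer token is just selecting a row of the matrix, which gives the same result as multiplying a one-hot vector by that matrix.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "vocab_size, embedding_size = 6, 3      # toy sizes, not the notebook's values\n",
    "embedding_matrix = rng.normal(size=(vocab_size, embedding_size))\n",
    "\n",
    "token_ids = np.array([4, 2, 0, 0])     # a short sequence padded with zeros\n",
    "\n",
    "# Lookup: pick one row of the matrix per token\n",
    "looked_up = embedding_matrix[token_ids]\n",
    "\n",
    "# Equivalent but wasteful: build one-hot vectors and multiply\n",
    "one_hot = np.eye(vocab_size)[token_ids]\n",
    "via_matmul = one_hot @ embedding_matrix\n",
    "\n",
    "print(looked_up.shape)\n",
    "print(np.allclose(looked_up, via_matmul))\n",
    "```\n",
    "\n",
    "The lookup form avoids ever materializing the one-hot vectors, which matters when the vocabulary has thousands of entries."
   ]
  },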
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "VOCAB_SIZE = len(vectorize_layer.get_vocabulary())\n",
    "EMBEDDING_SIZE = int(np.sqrt(VOCAB_SIZE))\n",
    "print(\"Vocab size is {} and is embedded into {} dimensions\".format(VOCAB_SIZE, EMBEDDING_SIZE))\n",
    "\n",
    "embedding_layer = tf.keras.layers.Embedding(VOCAB_SIZE, EMBEDDING_SIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "TensorFlow separates the declaration of RNN cells from the [RNNs](https://www.tensorflow.org/api_docs/python/tf/keras/layers/RNN) that run them. In the code below, we declare two [LSTM cells](https://www.tensorflow.org/api_docs/python/tf/keras/layers/LSTMCell) and pass them to the RNN as a list, which stacks them: at every timestep the output of the first cell is the input to the second. By default the RNN returns only the output at the final timestep."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the model\n",
    "cells = [tf.keras.layers.LSTMCell(256), tf.keras.layers.LSTMCell(64)]\n",
    "rnn = tf.keras.layers.RNN(cells)\n",
    "output_layer = tf.keras.layers.Dense(1)\n",
    "\n",
    "model = tf.keras.Sequential([vectorize_layer, embedding_layer, rnn, output_layer])\n",
    "\n",
    "# test a forward pass\n",
    "for batch in ds['train'].batch(32):\n",
    "    logits = model(batch['text'])\n",
    "    loss = tf.keras.losses.binary_crossentropy(tf.expand_dims(batch['label'], -1), logits, from_logits=True)\n",
    "    print(loss)\n",
    "    break"
   ]
  },
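  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The forward pass above passes `from_logits=True`, so the single value produced by the `Dense(1)` layer is a logit rather than a probability. The NumPy sketch below (with made-up logits and labels, and a helper name `bce_from_logits` chosen for this example) shows how a logit turns into a probability, a predicted class, and a binary cross-entropy value.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def sigmoid(x):\n",
    "    return 1.0 / (1.0 + np.exp(-x))\n",
    "\n",
    "def bce_from_logits(labels, logits):\n",
    "    # Numerically stable form: max(z, 0) - z * y + log(1 + exp(-|z|))\n",
    "    return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))\n",
    "\n",
    "logits = np.array([-2.0, 0.0, 3.0])    # made-up model outputs\n",
    "labels = np.array([0.0, 1.0, 1.0])     # made-up sentiment labels\n",
    "\n",
    "probs = sigmoid(logits)                   # probability of the positive class\n",
    "predicted = (probs >= 0.5).astype(int)    # same decision as (logits >= 0)\n",
    "loss = bce_from_logits(labels, logits)    # per-example loss\n",
    "\n",
    "print(probs)\n",
    "print(predicted)\n",
    "print(loss)\n",
    "```\n",
    "\n",
    "Thresholding the probability at 0.5 and thresholding the logit at 0 give the same predicted class, so the sigmoid is only needed when you want the probability itself."
   ]
  },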
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, to do sentiment analysis, we'll treat this as a classification problem. Because it's only a 2-class problem, we'll use binary cross-entropy and output only one value. Since the loss is computed with `from_logits=True`, that value is a logit: after a sigmoid, the prediction is class zero if the probability is $<0.5$ and class one if it's $\\geq0.5$ (equivalently, if the logit is negative or non-negative). We can train this model in the usual way.\n",
    "\n",
    "### Attention\n",
    "\n",
    "One way to enhance the performance of sequential models is a mechanism called attention. Essentially, the idea is that processing a token correctly might depend on tokens far away in the sequence, so the model learns which other positions to look at. We're going to use Luong-style attention, a classic variety first employed in language translation. It's implemented by `tf.keras.layers.Attention`."
   ]
  },
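  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of what dot-product (Luong-style) attention computes, the NumPy code below scores every query position against every value position, normalizes the scores with a softmax, and returns a weighted sum of the values. It ignores masking, dropout, and the optional learned scale, and the function names are assumptions for this example rather than the Keras implementation.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def softmax(x, axis=-1):\n",
    "    x = x - x.max(axis=axis, keepdims=True)   # subtract max for stability\n",
    "    e = np.exp(x)\n",
    "    return e / e.sum(axis=axis, keepdims=True)\n",
    "\n",
    "def dot_product_attention(query, value):\n",
    "    # query: (batch, Tq, dim), value: (batch, Tv, dim); value doubles as the key\n",
    "    scores = query @ value.transpose(0, 2, 1)   # (batch, Tq, Tv)\n",
    "    weights = softmax(scores, axis=-1)          # each row sums to 1\n",
    "    return weights @ value, weights             # (batch, Tq, dim), (batch, Tq, Tv)\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "query = rng.normal(size=(2, 5, 8))   # toy shapes: batch=2, seq=5, dim=8\n",
    "value = rng.normal(size=(2, 5, 8))\n",
    "attended, weights = dot_product_attention(query, value)\n",
    "print(attended.shape, weights.shape)\n",
    "```\n",
    "\n",
    "The output has one vector per query position, which is why the attended tensor can be concatenated with the query along the last axis."
   ]
  },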
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We'll make two conv layers to produce the query and value tensors\n",
    "query_layer = tf.keras.layers.Conv1D(\n",
    "    filters=100,\n",
    "    kernel_size=4,\n",
    "    padding='same')\n",
    "value_layer = tf.keras.layers.Conv1D(\n",
    "    filters=100,\n",
    "    kernel_size=4,\n",
    "    padding='same')\n",
    "# Then they will be input to the Attention layer\n",
    "attention = tf.keras.layers.Attention()\n",
    "concat = tf.keras.layers.Concatenate()\n",
    "\n",
    "cells = [tf.keras.layers.LSTMCell(256), tf.keras.layers.LSTMCell(64)]\n",
    "rnn = tf.keras.layers.RNN(cells)\n",
    "output_layer = tf.keras.layers.Dense(1)\n",
    "\n",
    "for batch in ds['train'].batch(32):\n",
    "    text = batch['text']\n",
    "    embeddings = embedding_layer(vectorize_layer(text))\n",
    "    query = query_layer(embeddings)\n",
    "    value = value_layer(embeddings)\n",
    "    query_value_attention = attention([query, value])\n",
    "    print(\"Shape after attention is (batch, seq, filters):\", query_value_attention.shape)\n",
    "    attended_values = concat([query, query_value_attention])\n",
    "    print(\"Shape after concatenating is (batch, seq, 2 * filters):\", attended_values.shape)\n",
    "    logits = output_layer(rnn(attended_values))\n",
    "    loss = tf.keras.losses.binary_crossentropy(tf.expand_dims(batch['label'], -1), logits, from_logits=True)\n",
    "    print(loss)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this way we can produce a very similar RNN over a sequence, but incorporate more dynamic relationships between distant tokens.\n",
    "\n",
    "### Homework\n",
    "\n",
    "Nothing. There is no homework in the hackathon this week. Thank you for bearing with the early pace of this course in the midst of a once-in-a-century pandemic."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dl_notebooks",
   "language": "python",
   "name": "dl_notebooks"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
